fine_tune/Llama-Vision FT.ipynb (323 lines of code) (raw):

{ "cells": [ { "cell_type": "markdown", "id": "8e0b5326-74fd-4846-9f53-fbb5b51afdae", "metadata": {}, "source": [ "## Fine-tuning Llama 3.2 Vision using Trainer" ] }, { "cell_type": "markdown", "id": "c6b3112a-1e22-4da0-9713-ded4483bb732", "metadata": {}, "source": [ "The Transformers Trainer API makes it easy to fine-tune Llama-Vision models. Parameter-efficient fine-tuning techniques also work out of the box thanks to the transformers integration with PEFT. Make sure you have the latest version of transformers installed.\n", "\n", "\n", "We will fine-tune the model on a small split of the VQAv2 dataset for educational purposes. This dataset consists of images, questions about the images, and short answers. If you want, you can also use a dataset in which each example contains multiple turns of conversation.\n" ] }, { "cell_type": "code", "execution_count": 3, "id": "ee883053-3f16-479e-9d55-49a05e62c400", "metadata": {}, "outputs": [], "source": [ "from datasets import load_dataset\n", "\n", "ds = load_dataset(\"merve/vqav2-small\", split=\"validation[:10%]\")" ] }, { "cell_type": "code", "execution_count": 4, "id": "9ce170d7-4fc5-4b08-b4e3-b884b82c9495", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Dataset({\n", "    features: ['multiple_choice_answer', 'question', 'image'],\n", "    num_rows: 2144\n", "})" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "ds" ] }, { "cell_type": "markdown", "id": "176d19d0-a55a-4267-b2e3-0fb9e13ec80f", "metadata": {}, "source": [ "We have to authenticate ourselves before downloading the model, because the Llama checkpoints are gated on the Hugging Face Hub." ] }, { "cell_type": "code", "execution_count": 5, "id": "7a5fe387-e0ad-4b36-b0de-a231d1a604a9", "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "acce430902c94d979584b0de179c23fc", "version_major": 2, "version_minor": 0 }, "text/plain": [ "VBox(children=(HTML(value='<center> <img\\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "from huggingface_hub import notebook_login\n", "notebook_login()" ] }, { "cell_type": "markdown", "id": "dadba5ec-e4a6-480c-a09f-e4692fad6f79", "metadata": {}, "source": [ "We can now initialize the model and the processor; the processor is needed later in our preprocessing function. We will load the 11B variant of the vision model. \n", "\n", "The Llama authors encourage freezing the text decoder and training only the image encoder. If you would like to try this, set `FREEZE_LLM` to `True`. Intuitively, if your task is highly domain-specific, you might want to avoid this. 
In that case, you can either try LoRA training (set `USE_LORA` to `True`) or freeze the image encoder (set `FREEZE_IMAGE` to `True`) to save compute.\n" ] }, { "cell_type": "code", "execution_count": null, "id": "a98466a9", "metadata": {}, "outputs": [], "source": [ "from transformers import MllamaForConditionalGeneration, AutoProcessor, BitsAndBytesConfig\n", "from peft import LoraConfig, get_peft_model\n", "import torch\n", "\n", "ckpt = \"meta-llama/Llama-3.2-11B-Vision\"\n", "USE_LORA = True\n", "FREEZE_LLM = False\n", "FREEZE_IMAGE = False\n", "\n", "if USE_LORA:\n", "    lora_config = LoraConfig(\n", "        r=8,\n", "        lora_alpha=8,\n", "        lora_dropout=0.1,\n", "        target_modules=['down_proj','o_proj','k_proj','q_proj','gate_proj','up_proj','v_proj'],\n", "        use_dora=True, # optional DoRA\n", "        init_lora_weights=\"gaussian\"\n", "    )\n", "\n", "    model = MllamaForConditionalGeneration.from_pretrained(\n", "        ckpt,\n", "        torch_dtype=torch.bfloat16,\n", "        device_map=\"auto\"\n", "    )\n", "\n", "    model = get_peft_model(model, lora_config)\n", "    model.print_trainable_parameters()\n", "\n", "elif FREEZE_IMAGE:\n", "    if FREEZE_LLM:\n", "        raise ValueError(\"You cannot freeze image encoder and text decoder at the same time.\")\n", "    model = MllamaForConditionalGeneration.from_pretrained(ckpt,\n", "        torch_dtype=torch.bfloat16, device_map=\"auto\")\n", "    # freeze vision model to save up on compute\n", "    for param in model.vision_model.parameters():\n", "        param.requires_grad = False\n", "\n", "elif FREEZE_LLM:\n", "    if FREEZE_IMAGE:\n", "        raise ValueError(\"You cannot freeze image encoder and text decoder at the same time.\")\n", "    model = MllamaForConditionalGeneration.from_pretrained(ckpt,\n", "        torch_dtype=torch.bfloat16, device_map=\"auto\")\n", "    # freeze text model, this is encouraged in paper\n", "    for param in model.language_model.parameters():\n", "        param.requires_grad = False\n", "\n", "else: # full ft\n", "    model = MllamaForConditionalGeneration.from_pretrained(ckpt,\n", "        torch_dtype=torch.bfloat16, device_map=\"auto\")\n", "\n", "processor = AutoProcessor.from_pretrained(ckpt)" ] }, { "cell_type": "markdown", "id": "72738779-b3b9-424b-92c1-c72b59898543", "metadata": {}, "source": [ "For preprocessing, we will concatenate each question with its answer. Between them we insert a conditioning phrase that prompts the model for a short answer; in this case it is “Answer briefly.”\n", "To process images, we wrap each image in its own single-element list. This is needed because the processor accepts multiple images for a single text input, so we have to indicate that each example has exactly one image.\n", "Lastly, we will set the labels of pad tokens and image tokens to -100 so that the loss ignores them.\n" ] }, { "cell_type": "code", "execution_count": null, "id": "467eeb17-092a-42fd-9ad7-48642cc03731", "metadata": {}, "outputs": [], "source": [ "def process(examples):\n", "    texts = [f\"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\\n\\n<|image|>{example['question']} Answer briefly. <|eot_id|><|start_header_id|>assistant<|end_header_id|>\\n\\n{example['multiple_choice_answer']}<|eot_id|>\" for example in examples]\n", "    images = [[example[\"image\"].convert(\"RGB\")] for example in examples]\n", "\n", "    batch = processor(text=texts, images=images, return_tensors=\"pt\", padding=True)\n", "    labels = batch[\"input_ids\"].clone()\n", "    labels[labels == processor.tokenizer.pad_token_id] = -100\n", "    labels[labels == 128256] = -100 # image token index\n", "    batch[\"labels\"] = labels\n", "    batch = batch.to(torch.bfloat16).to(\"cuda\")\n", "\n", "    return batch\n" ] }, { "cell_type": "markdown", "id": "971e0923-d0ef-43b7-9db5-5008f78b2782", "metadata": {}, "source": [ "We can now set up our Trainer. First, we define the arguments we will pass to it."
] }, { "cell_type": "code", "execution_count": null, "id": "d19d7bab-aa59-4c78-81a1-10724efe97d0", "metadata": {}, "outputs": [], "source": [ "from transformers import TrainingArguments\n", "args = TrainingArguments(\n", "    num_train_epochs=2,\n", "    remove_unused_columns=False,\n", "    per_device_train_batch_size=1,\n", "    gradient_accumulation_steps=4,\n", "    warmup_steps=2,\n", "    learning_rate=2e-5,\n", "    weight_decay=1e-6,\n", "    adam_beta2=0.999,\n", "    logging_steps=250,\n", "    save_strategy=\"no\",\n", "    optim=\"adamw_hf\",\n", "    push_to_hub=True,\n", "    save_total_limit=1,\n", "    bf16=True,\n", "    output_dir=\"./lora\",\n", "    dataloader_pin_memory=False,\n", ")" ] }, { "cell_type": "markdown", "id": "3083279f-b084-4e6c-8175-4986cc53ad2a", "metadata": {}, "source": [ "We can now initialize the Trainer and start training.\n" ] }, { "cell_type": "code", "execution_count": null, "id": "af2971d6-318e-4d4a-952e-6a15d511ea3b", "metadata": {}, "outputs": [], "source": [ "from transformers import Trainer\n", "trainer = Trainer(\n", "    model=model,\n", "    train_dataset=ds,\n", "    data_collator=process,\n", "    args=args\n", ")" ] }, { "cell_type": "markdown", "id": "8d1d91a6", "metadata": {}, "source": [ "Call `train()` to start training."
] }, { "cell_type": "code", "execution_count": null, "id": "4c5fca52-c078-481f-922e-c22c49ce412a", "metadata": {}, "outputs": [], "source": [ "trainer.train()" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.9" }, "varInspector": { "cols": { "lenName": 16, "lenType": 16, "lenVar": 40 }, "kernels_config": { "python": { "delete_cmd_postfix": "", "delete_cmd_prefix": "del ", "library": "var_list.py", "varRefreshCmd": "print(var_dic_list())" }, "r": { "delete_cmd_postfix": ") ", "delete_cmd_prefix": "rm(", "library": "var_list.r", "varRefreshCmd": "cat(var_dic_list()) " } }, "types_to_exclude": [ "module", "function", "builtin_function_or_method", "instance", "_Feature" ], "window_display": false } }, "nbformat": 4, "nbformat_minor": 5 }